{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nittfHfNvRoS"
      },
      "source": [
        "##### Copyright 2020 Google LLC.\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4Ptc9JShvY9v"
      },
      "outputs": [],
      "source": [
        "# Copyright 2020 The Google Research Authors.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sRywXbP2T-bX"
      },
      "source": [
        "# M-layer experiments\n",
        "This notebook trains M-layers on the problems discussed in \"Intelligent Matrix Exponentiation\"."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GDkojBVAeCQs"
      },
      "source": [
        "\n",
        "\n",
        "When running this notebook locally, the `m_layer` Python module is distributed together with this colab and should already be present.\n",
        "\n",
        "Otherwise, the code of the `m_layer` Python module can be downloaded from the google-research GitHub repository."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HeXHggYmO1Qh"
      },
      "outputs": [],
      "source": [
        "import os.path\n",
        "\n",
        "# Prefer a local copy of the module; otherwise fetch it from GitHub.\n",
        "if os.path.isfile('m_layer.py'):\n",
        "  from m_layer import MLayer\n",
        "else:\n",
        "  if not os.path.isfile(os.path.join('m_layer', 'm_layer.py')):\n",
        "    # GitHub no longer serves Subversion, so `svn export` of a repository\n",
        "    # subdirectory fails. Fetch only the `m_layer` directory with a shallow,\n",
        "    # sparse git clone instead; skipped when an earlier run already did so.\n",
        "    !git clone --depth 1 --filter=blob:none --sparse https://github.com/google-research/google-research.git google-research-sparse\n",
        "    !git -C google-research-sparse sparse-checkout set m_layer\n",
        "    !cp -r google-research-sparse/m_layer .\n",
        "  from m_layer.m_layer import MLayer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1hkfYu6rR6Wm"
      },
      "outputs": [],
      "source": [
        "GLOBAL_SEED = 1\n",
        "import numpy as np\n",
        "np.random.seed(GLOBAL_SEED)\n",
        "import itertools\n",
        "import functools\n",
        "import operator\n",
        "import logging\n",
        "logging.getLogger('tensorflow').disabled = True\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "from matplotlib import pylab\n",
        "\n",
        "print(tf.__version__)\n",
        "print(tf.config.experimental.list_physical_devices('GPU'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k1dKgIC1WBNQ"
      },
      "source": [
        "# Generate a spiral and show extrapolation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OS_m1ZmqWF0I",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "# Hyperparameters for the spiral classification experiment.\n",
        "SPIRAL_DIM_REP = 10  # Units of the Dense layer feeding the M-layer.\n",
        "SPIRAL_DIM_MATRIX = 10  # `dim_m` passed to MLayer.\n",
        "SPIRAL_LAYER_SIZE = 20  # Units per hidden layer of the DNN baseline.\n",
        "SPIRAL_LR = 0.01  # RMSprop learning rate.\n",
        "SPIRAL_EPOCHS = 1000  # Upper bound on epochs passed to fit().\n",
        "SPIRAL_BATCH_SIZE = 16\n",
        "\n",
        "def spiral_m_layer_model():\n",
        "  \"\"\"Builds the M-layer binary classifier for 2-d spiral points.\"\"\"\n",
        "  model = tf.keras.models.Sequential()\n",
        "  # Dense layer without activation maps the 2-d input to the M-layer input.\n",
        "  model.add(tf.keras.layers.Dense(SPIRAL_DIM_REP, input_shape=(2,)))\n",
        "  model.add(MLayer(dim_m=SPIRAL_DIM_MATRIX,\n",
        "                   with_bias=True,\n",
        "                   matrix_squarings_exp=None,\n",
        "                   matrix_init='normal'))\n",
        "  model.add(tf.keras.layers.ActivityRegularization(l2=1e-3))\n",
        "  model.add(tf.keras.layers.Flatten())\n",
        "  # Single sigmoid unit: probability of the second class.\n",
        "  model.add(tf.keras.layers.Dense(1, activation='sigmoid'))\n",
        "  return model\n",
        "\n",
        "def spiral_dnn_model(activation_type):\n",
        "  \"\"\"Builds the dense baseline classifier for 2-d spiral points.\n",
        "\n",
        "  Args:\n",
        "    activation_type: Activation used by both hidden layers.\n",
        "  \"\"\"\n",
        "  model = tf.keras.models.Sequential()\n",
        "  model.add(tf.keras.layers.Flatten(input_shape=(2,)))\n",
        "  # Two identically configured hidden layers.\n",
        "  for _ in range(2):\n",
        "    model.add(tf.keras.layers.Dense(SPIRAL_LAYER_SIZE,\n",
        "                                    activation=activation_type))\n",
        "  model.add(tf.keras.layers.Dense(1, activation='sigmoid'))\n",
        "  return model\n",
        "\n",
        "def spiral_generate(n_points, noise=0.5, rng=None, extra_rotation=False):\n",
        "  if rng is None:\n",
        "    rng = np.random.RandomState()\n",
        "  if not extra_rotation:\n",
        "    n = np.sqrt(0.001 + (.25)*rng.rand(n_points, 1)) * 6 * (2 * np.pi)\n",
        "  else:\n",
        "    n = np.sqrt((7.0/36)*rng.rand(n_points, 1)+.25) * 6 * (2 * np.pi)\n",
        "  x = 0.5 * (np.sin(n) * n + (2 * rng.rand(n_points, 1) - 1) * noise)\n",
        "  y = 0.5 * (np.cos(n) * n + (2 * rng.rand(n_points, 1) - 1) * noise)\n",
        "  return (np.vstack((np.hstack((x, y)), np.hstack((-x, -y)))),\n",
        "          np.hstack((np.zeros(n_points), np.ones(n_points))))\n",
        "\n",
        "\n",
        "def spiral_run(model_type, fig=None, activation_type=None, ):\n",
        "  \"\"\"Trains a spiral classifier and plots its decision regions.\n",
        "\n",
        "  Args:\n",
        "    model_type: \"dnn\" selects the dense baseline; any other value selects\n",
        "      the M-layer model.\n",
        "    fig: Optional matplotlib figure to draw into; created if None.\n",
        "    activation_type: Hidden-layer activation of the \"dnn\" baseline.\n",
        "\n",
        "  Returns:\n",
        "    Tuple `(fig, n_epochs, final_loss)`: the figure, the number of epochs\n",
        "    recorded in the training history, and the last recorded training loss.\n",
        "  \"\"\"\n",
        "  if fig is None:\n",
        "    fig = pylab.figure(figsize=(8,8), dpi=144)\n",
        "  if model_type == \"dnn\":\n",
        "    model = spiral_dnn_model(activation_type)\n",
        "  else:\n",
        "    model = spiral_m_layer_model()\n",
        "  x_train, y_train = spiral_generate(1000)\n",
        "  # `extra_rotation=True` samples the part of the spiral beyond the training\n",
        "  # range; these points are only plotted, never fitted.\n",
        "  x_test, y_test = spiral_generate(333, extra_rotation=True)\n",
        "  model.summary()\n",
        "  # `learning_rate` replaces the deprecated `lr` alias.\n",
        "  opt = tf.keras.optimizers.RMSprop(learning_rate=SPIRAL_LR)\n",
        "  model.compile(loss='binary_crossentropy',\n",
        "                optimizer=opt,\n",
        "                metrics=['accuracy'])\n",
        "  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n",
        "      monitor='loss', factor=0.2, patience=5, min_lr=1e-5)\n",
        "  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss',\n",
        "                                                    patience=30,\n",
        "                                                    min_delta=0.0001,\n",
        "                                                    restore_best_weights=True)\n",
        "  result = model.fit(x_train, y_train, epochs=SPIRAL_EPOCHS,\n",
        "            batch_size=SPIRAL_BATCH_SIZE, verbose=2,\n",
        "            callbacks=[reduce_lr, early_stopping])\n",
        "  n_epochs = len(result.history['loss'])\n",
        "  delta = 0.5 ** 3\n",
        "  xs = np.arange(-14, 14.01, delta)\n",
        "  ys = np.arange(-14, 14.01, delta)\n",
        "  # Model inputs on a regular grid, first coordinate varying slowest.\n",
        "  grid = np.array([[x, y] for x in xs for y in ys])\n",
        "  t_nn_gen = model.predict(grid)\n",
        "  axes = fig.gca()\n",
        "  XX, YY = np.meshgrid(xs, ys)\n",
        "  # After the reshape, the model's first input coordinate runs along the\n",
        "  # vertical plot axis; the scatter plots below swap columns to match.\n",
        "  decision = np.arcsinh(t_nn_gen.reshape(XX.shape))\n",
        "  axes.contourf(XX, YY, decision,\n",
        "             levels=[0.0, 0.5, 1.0],\n",
        "             colors=[(0.41, 0.67, 0.81, 0.2), (0.89, 0.51, 0.41, 0.2)])\n",
        "  axes.contour(XX, YY, decision,\n",
        "             levels=[0.5])\n",
        "  axes.set_aspect(1)\n",
        "  axes.grid()\n",
        "  axes.plot(x_train[y_train==0, 1], x_train[y_train==0, 0], '.', ms = 2,\n",
        "            label='Class 1')\n",
        "  axes.plot(x_train[y_train==1, 1], x_train[y_train==1, 0], '.', ms = 2,\n",
        "            label='Class 2')\n",
        "  # Draw on `axes` (not via pyplot) so the test points land on `fig` even\n",
        "  # when the caller passed a figure that is not the current pyplot figure.\n",
        "  axes.plot(x_test[y_test==1, 1], x_test[y_test==1, 0], '.', ms = .5,\n",
        "            label='Class 2')\n",
        "  axes.plot(x_test[y_test==0, 1], x_test[y_test==0, 0], '.', ms = .5,\n",
        "            label='Class 1')\n",
        "\n",
        "  return fig, n_epochs, result.history['loss'][-1]\n",
        "\n",
        "# Train the M-layer variant on the spiral data and plot its decision regions.\n",
        "fig, n_epochs, loss = spiral_run('m_layer')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3A_iJ6_Abg74"
      },
      "source": [
        "# Train an M-layer on multivariate polynomials such as the determinant"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I90YlGBDbojG"
      },
      "outputs": [],
      "source": [
        "# Hyperparameters for the determinant / permanent regression experiment.\n",
        "POLY_BATCH_SIZE = 32\n",
        "POLY_DIM_MATRIX = 8  # `dim_m` passed to MLayer.\n",
        "POLY_DIM_INPUT_MATRIX = 3  # Side length of the square input matrices.\n",
        "POLY_EPOCHS = 150  # Upper bound on epochs passed to fit().\n",
        "POLY_SEED = 123  # Seeds the RandomState that samples the matrices.\n",
        "POLY_LOW = -1  # Lower bound of the uniform matrix entries.\n",
        "POLY_HIGH = 1  # Upper bound of the uniform matrix entries.\n",
        "# poly_run draws 5/4 of this many matrices and holds out 20% for validation.\n",
        "POLY_NUM_SAMPLES = 8192\n",
        "POLY_LR = 1e-3  # RMSprop learning rate.\n",
        "POLY_DECAY = 1e-6  # RMSprop `decay` argument.\n",
        "\n",
        "def poly_get_model():\n",
        "  \"\"\"Builds the M-layer regressor for square input matrices.\"\"\"\n",
        "  input_shape = (POLY_DIM_INPUT_MATRIX, POLY_DIM_INPUT_MATRIX)\n",
        "  model = tf.keras.models.Sequential()\n",
        "  # The matrix entries are fed to the M-layer as one flat vector.\n",
        "  model.add(tf.keras.layers.Flatten(input_shape=input_shape))\n",
        "  model.add(MLayer(dim_m=POLY_DIM_MATRIX, matrix_init='normal'))\n",
        "  model.add(tf.keras.layers.ActivityRegularization(l2=1e-4))\n",
        "  model.add(tf.keras.layers.Flatten())\n",
        "  # Single output unit without activation: the regression target.\n",
        "  model.add(tf.keras.layers.Dense(1))\n",
        "  return model\n",
        "\n",
        "\n",
        "def poly_fun(x, permanent=False):\n",
        "  if permanent:\n",
        "    return sum(\n",
        "        functools.reduce(\n",
        "            operator.mul,\n",
        "            (x[i, pi] for i, pi in enumerate(perm)),\n",
        "            1)\n",
        "        for perm in itertools.permutations(range(x.shape[0])))\n",
        "  return np.linalg.det(x)\n",
        "\n",
        "\n",
        "def poly_run(permanent=False):\n",
        "  \"\"\"Trains the polynomial M-layer model and prints train/test losses.\n",
        "\n",
        "  Args:\n",
        "    permanent: If True the regression target is the matrix permanent,\n",
        "      otherwise the determinant.\n",
        "  \"\"\"\n",
        "  rng = np.random.RandomState(seed=POLY_SEED)\n",
        "  # 20% of the drawn matrices are held out by `validation_split` below, so\n",
        "  # POLY_NUM_SAMPLES matrices remain for the gradient updates.\n",
        "  num_train = POLY_NUM_SAMPLES * 5 // 4\n",
        "  x_train = rng.uniform(size=(num_train, POLY_DIM_INPUT_MATRIX,\n",
        "                              POLY_DIM_INPUT_MATRIX), low=POLY_LOW,\n",
        "                         high=POLY_HIGH)\n",
        "  x_test = rng.uniform(size=(100000, POLY_DIM_INPUT_MATRIX,\n",
        "                             POLY_DIM_INPUT_MATRIX), low=POLY_LOW,\n",
        "                       high=POLY_HIGH)\n",
        "  y_train = np.array([poly_fun(x, permanent=permanent) for x in x_train])\n",
        "  y_test = np.array([poly_fun(x, permanent=permanent) for x in x_test])\n",
        "  model = poly_get_model()\n",
        "  model.summary()\n",
        "  # `learning_rate` replaces the deprecated `lr` alias.\n",
        "  # NOTE(review): `decay` appears to be accepted only by the legacy Keras\n",
        "  # optimizers - confirm against the TensorFlow version in use.\n",
        "  opt = tf.keras.optimizers.RMSprop(learning_rate=POLY_LR, decay=POLY_DECAY)\n",
        "\n",
        "  model.compile(loss='mse', optimizer=opt)\n",
        "  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n",
        "      monitor='val_loss', factor=0.2, patience=5, min_lr=1e-5)\n",
        "  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', \n",
        "                                                    patience=30,\n",
        "                                                    restore_best_weights=True)\n",
        "  model.fit(x_train, y_train, batch_size=POLY_BATCH_SIZE,\n",
        "            epochs=POLY_EPOCHS,\n",
        "            validation_split=0.2,\n",
        "            shuffle=True,\n",
        "            verbose=2,\n",
        "            callbacks=[reduce_lr, early_stopping])\n",
        "  # The train score is evaluated on all of x_train, validation part included.\n",
        "  score_train = model.evaluate(x=x_train, y=y_train)\n",
        "  score_test = model.evaluate(x=x_test, y=y_test)\n",
        "\n",
        "  print('Train, range %s - %s: %s' % (POLY_LOW, POLY_HIGH, score_train))\n",
        "  print('Test, range %s - %s: %s' % (POLY_LOW, POLY_HIGH, score_test))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N3tpvQ5Jb4wN"
      },
      "source": [
        "Permanents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uhfD0LE4b8CG"
      },
      "outputs": [],
      "source": [
        "# Fit the M-layer model to the matrix permanent.\n",
        "poly_run(permanent=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aKxPzc9pb-OJ"
      },
      "source": [
        "Determinants"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gYSMrS3xcAO9"
      },
      "outputs": [],
      "source": [
        "# Fit the M-layer model to the matrix determinant.\n",
        "poly_run(permanent=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5RyUMrZtcPxk"
      },
      "source": [
        "# Train an M-layer on periodic data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bOSM0W14cUzF"
      },
      "outputs": [],
      "source": [
        "# Hyperparameters for the periodic (temperature) experiment.\n",
        "PERIODIC_EPOCHS = 1000\n",
        "PERIODIC_BATCH_SIZE = 128\n",
        "PERIODIC_LR = 0.00001  # RMSprop learning rate.\n",
        "PERIODIC_DIM_MATRIX = 10  # First positional argument of MLayer.\n",
        "PERIODIC_INIT_SCALE = 0.01  # Std-dev of the noise in periodic_matrix_init.\n",
        "PERIODIC_DIAG_INIT = 10  # Subtracted from the diagonal in periodic_matrix_init.\n",
        "PERIODIC_SEED = 123  # Seeds the RandomState handed to the model factory.\n",
        "\n",
        "def periodic_matrix_init(shape, rng=None, **kwargs):\n",
        "  \"\"\"Returns float32 normal noise with the diagonals shifted downwards.\n",
        "\n",
        "  Args:\n",
        "    shape: Shape of the array to create; indexed here as a stack of\n",
        "      matrices `[k, i, j]`.\n",
        "    rng: Optional `np.random.RandomState`; a new one is created if None.\n",
        "    **kwargs: Accepted and ignored.\n",
        "\n",
        "  Returns:\n",
        "    Array of the given shape whose entries `[:, i, i]` have\n",
        "    PERIODIC_DIAG_INIT subtracted for every `i < shape[1]`.\n",
        "  \"\"\"\n",
        "  if rng is None:\n",
        "    rng = np.random.RandomState()\n",
        "  noise = rng.normal(loc=0, scale=PERIODIC_INIT_SCALE, size=shape)\n",
        "  data = np.float32(noise)\n",
        "  # Address all diagonal entries of every matrix in the stack at once.\n",
        "  diag = np.arange(shape[1])\n",
        "  data[:, diag, diag] -= PERIODIC_DIAG_INIT\n",
        "  return data\n",
        "\n",
        "def periodic_get_model(rng=None):\n",
        "  \"\"\"Builds the M-layer regressor for the one-dimensional periodic task.\n",
        "\n",
        "  Args:\n",
        "    rng: Optional `np.random.RandomState` used to initialise the M-layer\n",
        "      matrices; a new one is created if None.\n",
        "  \"\"\"\n",
        "  if rng is None:\n",
        "    rng = np.random.RandomState()\n",
        "\n",
        "  def matrix_init(shape, **kwargs):\n",
        "    # Forward to periodic_matrix_init with this model's random stream.\n",
        "    return periodic_matrix_init(shape, rng=rng, **kwargs)\n",
        "\n",
        "  model = tf.keras.models.Sequential()\n",
        "  model.add(tf.keras.layers.Dense(\n",
        "      2, input_shape=(1,),\n",
        "      kernel_initializer=tf.keras.initializers.RandomNormal()))\n",
        "  model.add(MLayer(PERIODIC_DIM_MATRIX, with_bias=True,\n",
        "                   matrix_squarings_exp=None, matrix_init=matrix_init))\n",
        "  model.add(tf.keras.layers.Flatten())\n",
        "  model.add(tf.keras.layers.Dense(1))\n",
        "  return model\n",
        "\n",
        "\n",
        "def periodic_dist2(y_true, y_pred):\n",
        "  \"\"\"Metric: `tf.nn.l2_loss` of the difference between targets and predictions.\"\"\"\n",
        "  residual = y_true - y_pred\n",
        "  return tf.nn.l2_loss(residual)\n",
        "\n",
        "def periodic_run(get_model):\n",
        "  \"\"\"Fits a model to smoothed daily temperatures and plots the result.\n",
        "\n",
        "  The first 90% of the smoothed series is used for training; the model is\n",
        "  then evaluated on the whole index range and plotted against the data.\n",
        "\n",
        "  Args:\n",
        "    get_model: Callable taking a `rng` keyword argument and returning a\n",
        "      Keras model that accepts inputs of shape `(1,)`.\n",
        "  \"\"\"\n",
        "  rng = np.random.RandomState(seed=PERIODIC_SEED)\n",
        "  # See README file for information about this dataset.\n",
        "  with tf.io.gfile.GFile('daily-min-temperatures.csv', 'r') as f:\n",
        "    data = pd.read_csv(f)\n",
        "  dates = data['Date']\n",
        "  temperatures = data['Temp']\n",
        "  # Subtract the mean and smooth with a 7-sample moving average; 'valid'\n",
        "  # mode drops the positions where the window does not fit entirely.\n",
        "  y = np.convolve(temperatures - np.mean(temperatures), np.full(7, 1 / 7),\n",
        "                  mode='valid')\n",
        "  num_train = 9 * len(y) // 10\n",
        "  # The model input is simply the position of the sample in the series.\n",
        "  x_all = np.arange(len(y))\n",
        "  x_train, x_test = x_all[:num_train], x_all[num_train:]\n",
        "  y_train, y_targets = y[:num_train], y[num_train:]\n",
        "\n",
        "  model_to_train = get_model(rng=rng)\n",
        "  model_input = tf.keras.layers.Input(shape=(1,))\n",
        "  model_output = model_to_train(model_input)\n",
        "  model = tf.keras.models.Model(inputs=model_input, outputs=model_output)\n",
        "\n",
        "  # `learning_rate` replaces the deprecated `lr` alias; the former `decay=0`\n",
        "  # argument is dropped because zero decay leaves the rate constant.\n",
        "  opt = tf.keras.optimizers.RMSprop(learning_rate=PERIODIC_LR)\n",
        "  # NOTE(review): no `monitor` is given, so the Keras default ('val_loss')\n",
        "  # applies, but fit() below receives no validation data - confirm whether\n",
        "  # monitoring 'loss' (with some patience) was intended.\n",
        "  early_stopping = tf.keras.callbacks.EarlyStopping(restore_best_weights=True)\n",
        "  model.compile(\n",
        "      loss='mean_squared_error', optimizer=opt,\n",
        "        metrics=[periodic_dist2])\n",
        "  model.fit(x_train, y_train,\n",
        "            batch_size=PERIODIC_BATCH_SIZE, epochs=PERIODIC_EPOCHS,\n",
        "            shuffle=True, verbose=1, callbacks=[early_stopping])\n",
        "  y_predictions = model.predict(x_all)\n",
        "\n",
        "  plt.plot(x_train, y_train, linewidth=1, alpha=0.7)\n",
        "  plt.plot(x_test, y_targets, linewidth=1, alpha=0.7)\n",
        "  plt.plot(x_all, y_predictions, color='magenta')\n",
        "  plt.legend(['y_train', 'y_targets', 'y_predictions'])\n",
        "  plt.xlim([0, 3650])\n",
        "  plt.ylabel('Temperature (Celsius)')\n",
        "  plt.grid(True, which='major', axis='both')\n",
        "  plt.grid(True, which='minor', axis='both')\n",
        "  # One tick per 1 January, labelled with the first four characters (year).\n",
        "  xtick_index = [i for i, date in enumerate(dates) if date.endswith('-01-01')]\n",
        "  plt.xticks(ticks=xtick_index,\n",
        "             labels=[x[:4] for x in dates[xtick_index].to_list()],\n",
        "             rotation=30)\n",
        "  plt.show()\n",
        "\n",
        "# Train the periodic M-layer model and plot its fit and extrapolation.\n",
        "periodic_run(periodic_get_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n6ePoKdpWxiC"
      },
      "source": [
        "# Train an M-layer on CIFAR-10\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eZMU20rKW5IY"
      },
      "outputs": [],
      "source": [
        "# Hyperparameters for the CIFAR-10 experiment.\n",
        "CIFAR_DIM_REP = 35  # Units of the Dense layer feeding the M-layer.\n",
        "CIFAR_DIM_MAT = 30  # `dim_m` passed to MLayer.\n",
        "CIFAR_LR = 1e-3  # SGD learning rate.\n",
        "CIFAR_DECAY = 1e-6  # SGD `decay` argument.\n",
        "CIFAR_MOMENTUM = 0.9\n",
        "CIFAR_BATCH_SIZE = 32\n",
        "CIFAR_EPOCHS = 150  # Upper bound on epochs passed to fit().\n",
        "CIFAR_NAME = 'cifar10'  # Dataset name passed to tfds.load.\n",
        "CIFAR_NUM_CLASSES = 10  # One-hot width and number of softmax outputs.\n",
        "\n",
        "def cifar_load_dataset():\n",
        "  \"\"\"Loads the dataset named CIFAR_NAME as NumPy arrays.\n",
        "\n",
        "  Returns:\n",
        "    `((x_train, y_train), (x_test, y_test))` where the images are float32\n",
        "    divided by 255 and the labels are one-hot with CIFAR_NUM_CLASSES columns.\n",
        "  \"\"\"\n",
        "  # batch_size=-1 makes tfds return each split as one full batch.\n",
        "  train_np = tfds.as_numpy(\n",
        "      tfds.load(CIFAR_NAME, split='train', with_info=False, batch_size=-1))\n",
        "  test_np = tfds.as_numpy(\n",
        "      tfds.load(CIFAR_NAME, split='test', with_info=False, batch_size=-1))\n",
        "\n",
        "  x_train = train_np['image']\n",
        "  x_test = test_np['image']\n",
        "  print('x_train shape:', x_train.shape)\n",
        "  print(x_train.shape[0], 'train samples')\n",
        "  print(x_test.shape[0], 'test samples')\n",
        "\n",
        "  train_pair = (x_train.astype('float32') / 255,\n",
        "                tf.keras.utils.to_categorical(train_np['label'],\n",
        "                                              CIFAR_NUM_CLASSES))\n",
        "  test_pair = (x_test.astype('float32') / 255,\n",
        "               tf.keras.utils.to_categorical(test_np['label'],\n",
        "                                             CIFAR_NUM_CLASSES))\n",
        "  return train_pair, test_pair\n",
        "\n",
        "def cifar_get_model():\n",
        "  \"\"\"Builds the M-layer classifier for 32x32x3 images.\"\"\"\n",
        "  model = tf.keras.models.Sequential()\n",
        "  model.add(tf.keras.layers.Flatten(input_shape=(32, 32, 3)))\n",
        "  model.add(tf.keras.layers.Dense(CIFAR_DIM_REP))\n",
        "  model.add(MLayer(dim_m=CIFAR_DIM_MAT, with_bias=True,\n",
        "                   matrix_squarings_exp=3))\n",
        "  # The first positional argument of ActivityRegularization is `l1`.\n",
        "  # NOTE(review): the other models in this notebook regularise with `l2` -\n",
        "  # confirm that an L1 penalty is intended here.\n",
        "  model.add(tf.keras.layers.ActivityRegularization(l1=1e-3))\n",
        "  model.add(tf.keras.layers.Flatten())\n",
        "  model.add(tf.keras.layers.Dense(CIFAR_NUM_CLASSES, activation='softmax'))\n",
        "  return model\n",
        "\n",
        "def cifar_run():\n",
        "  \"\"\"Trains the CIFAR-10 M-layer classifier and prints test loss/accuracy.\"\"\"\n",
        "  (x_train, y_train), (x_test, y_test) = cifar_load_dataset()\n",
        "  model = cifar_get_model()\n",
        "  model.summary()\n",
        "  # `learning_rate` replaces the deprecated `lr` alias.\n",
        "  # NOTE(review): `decay` appears to be accepted only by the legacy Keras\n",
        "  # optimizers - confirm against the TensorFlow version in use.\n",
        "  opt = tf.keras.optimizers.SGD(learning_rate=CIFAR_LR,\n",
        "                                momentum=CIFAR_MOMENTUM,\n",
        "                                decay=CIFAR_DECAY)\n",
        "\n",
        "  model.compile(loss='categorical_crossentropy', optimizer=opt,\n",
        "                metrics=['accuracy'])\n",
        "\n",
        "  # The metric is compiled as 'accuracy', so its validation value is logged\n",
        "  # as 'val_accuracy'. Monitoring 'val_acc' made both callbacks inert.\n",
        "  # `mode='max'` states explicitly that larger values are better.\n",
        "  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n",
        "      monitor='val_accuracy', mode='max', factor=0.2, patience=5,\n",
        "      min_lr=1e-5)\n",
        "  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy',\n",
        "                                                    mode='max',\n",
        "                                                    patience=15,\n",
        "                                                    restore_best_weights=True)\n",
        "\n",
        "  model.fit(\n",
        "      x_train,\n",
        "      y_train,\n",
        "      batch_size=CIFAR_BATCH_SIZE,\n",
        "      epochs=CIFAR_EPOCHS,\n",
        "      validation_split=0.1,\n",
        "      shuffle=True,\n",
        "      verbose=2,\n",
        "      callbacks=[reduce_lr, early_stopping])\n",
        "\n",
        "  # evaluate() returns [loss, accuracy], matching the compiled metrics.\n",
        "  scores = model.evaluate(x_test, y_test, verbose=0)\n",
        "  print('Test loss:', scores[0])\n",
        "  print('Test accuracy:', scores[1])\n",
        "\n",
        "# Train the CIFAR-10 M-layer classifier and report its test metrics.\n",
        "cifar_run()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "nittfHfNvRoS"
      ],
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "M-Layer Experiments",
      "provenance": [
        {
          "file_id": "1Gn-ScBfylFcW9ZskuXhlz5Rf3UR7fQVL",
          "timestamp": 1581696694023
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
